#!/bin/bash

# This script is used to compare the performance/accuracy of ONNX and TensorRT models.
#
# All paths below are relative to the current working directory, so run the
# script from a directory where ../output/onnx_output/chatglm_6b.onnx exists.

set -euo pipefail

inspect_model_path=../output/onnx_output/chatglm_6b.onnx

# Fail fast with readable messages instead of a buried error in the log file.
if ! command -v polygraphy >/dev/null 2>&1; then
    printf 'error: polygraphy not found in PATH\n' >&2
    exit 1
fi
if [[ ! -f "${inspect_model_path}" ]]; then
    printf 'error: ONNX model not found: %s\n' "${inspect_model_path}" >&2
    exit 1
fi

# Inspect the model; the full report (stdout and stderr) goes to result-inspect.log.
if ! polygraphy inspect model "${inspect_model_path}" > result-inspect.log 2>&1; then
    printf 'error: polygraphy inspect failed, see result-inspect.log\n' >&2
    exit 1
fi

# Earlier comparison runs, disabled and kept for reference.
# NOTE(review): both commands pass several "name:[shape]" specs inside ONE quoted
# string per option; Polygraphy documents one spec per argument, so quote each
# spec separately before re-enabling either command.
#polygraphy run ../output/onnx_output/chatglm_6b.onnx \
#    --onnxrt --trt \
#    --workspace 1000000000 \
#    --save-engine=model-FP32-MarkAll.plan \
#    --atol 1e-3 --rtol 1e-3 \
#    --verbose \
#    --onnx-outputs mark all \
#    --trt-outputs mark all \
#    --trt-min-shapes 'input_ids:[1,1] position_ids:[1,2,1] attention_mask:[1,1,1,1]' \
#    --trt-opt-shapes 'input_ids:[1,512] position_ids:[1,2,512] attention_mask:[1,1,512,512]' \
#    --trt-max-shapes 'input_ids:[1,1024] position_ids:[1,2,1024] attention_mask:[1,1,1024,1024]' \
#    --input-shapes   'input_ids:[1,512] position_ids:[1,2,512] attention_mask:[1,1,512,512]' \
#    > result-run-FP32-MarkAll-profile1.log 2>&1

# NOTE(review): this variant names past_key_values.1 .. past_key_values.28, while
# the active command below uses past_key_values.0 -- the indices may be off by
# one. Check the real input names in result-inspect.log before re-enabling.
# Each shape option is kept on a single comment line so the quoted list stays intact.
# polygraphy run ../output/onnx_output/chatglm_6b.onnx \
#     --onnxrt --trt \
#     --workspace 1000000000 \
#     --save-engine=model-FP32-MarkAll.plan \
#     --atol 1e-3 --rtol 1e-3 \
#     --verbose \
#     --onnx-outputs mark all \
#     --trt-outputs mark all \
#     --trt-min-shapes 'input_ids:[1,1] position_ids:[1,2,1] attention_mask:[1,1,1,1] past_key_values.1.decorder.key:[0,1,32,128] past_key_values.1.decorder.value:[0,1,32,128] past_key_values.2.decorder.key:[0,1,32,128] past_key_values.2.decorder.value:[0,1,32,128] past_key_values.3.decorder.key:[0,1,32,128] past_key_values.3.decorder.value:[0,1,32,128] past_key_values.4.decorder.key:[0,1,32,128] past_key_values.4.decorder.value:[0,1,32,128] past_key_values.5.decorder.key:[0,1,32,128] past_key_values.5.decorder.value:[0,1,32,128] past_key_values.6.decorder.key:[0,1,32,128] past_key_values.6.decorder.value:[0,1,32,128] past_key_values.7.decorder.key:[0,1,32,128] past_key_values.7.decorder.value:[0,1,32,128] past_key_values.8.decorder.key:[0,1,32,128] past_key_values.8.decorder.value:[0,1,32,128] past_key_values.9.decorder.key:[0,1,32,128] past_key_values.9.decorder.value:[0,1,32,128] past_key_values.10.decorder.key:[0,1,32,128] past_key_values.10.decorder.value:[0,1,32,128] past_key_values.11.decorder.key:[0,1,32,128] past_key_values.11.decorder.value:[0,1,32,128] past_key_values.12.decorder.key:[0,1,32,128] past_key_values.12.decorder.value:[0,1,32,128] past_key_values.13.decorder.key:[0,1,32,128] past_key_values.13.decorder.value:[0,1,32,128] past_key_values.14.decorder.key:[0,1,32,128] past_key_values.14.decorder.value:[0,1,32,128] past_key_values.15.decorder.key:[0,1,32,128] past_key_values.15.decorder.value:[0,1,32,128] past_key_values.16.decorder.key:[0,1,32,128] past_key_values.16.decorder.value:[0,1,32,128] past_key_values.17.decorder.key:[0,1,32,128] past_key_values.17.decorder.value:[0,1,32,128] past_key_values.18.decorder.key:[0,1,32,128] past_key_values.18.decorder.value:[0,1,32,128] past_key_values.19.decorder.key:[0,1,32,128] past_key_values.19.decorder.value:[0,1,32,128] past_key_values.20.decorder.key:[0,1,32,128] past_key_values.20.decorder.value:[0,1,32,128] past_key_values.21.decorder.key:[0,1,32,128] past_key_values.21.decorder.value:[0,1,32,128] past_key_values.22.decorder.key:[0,1,32,128] past_key_values.22.decorder.value:[0,1,32,128] past_key_values.23.decorder.key:[0,1,32,128] past_key_values.23.decorder.value:[0,1,32,128] past_key_values.24.decorder.key:[0,1,32,128] past_key_values.24.decorder.value:[0,1,32,128] past_key_values.25.decorder.key:[0,1,32,128] past_key_values.25.decorder.value:[0,1,32,128] past_key_values.26.decorder.key:[0,1,32,128] past_key_values.26.decorder.value:[0,1,32,128] past_key_values.27.decorder.key:[0,1,32,128] past_key_values.27.decorder.value:[0,1,32,128] past_key_values.28.decorder.key:[0,1,32,128] past_key_values.28.decorder.value:[0,1,32,128]' \
#     --trt-opt-shapes 'input_ids:[1,512] position_ids:[1,2,512] attention_mask:[1,1,512,512] past_key_values.1.decorder.key:[0,1,32,128] past_key_values.1.decorder.value:[0,1,32,128] past_key_values.2.decorder.key:[0,1,32,128] past_key_values.2.decorder.value:[0,1,32,128] past_key_values.3.decorder.key:[0,1,32,128] past_key_values.3.decorder.value:[0,1,32,128] past_key_values.4.decorder.key:[0,1,32,128] past_key_values.4.decorder.value:[0,1,32,128] past_key_values.5.decorder.key:[0,1,32,128] past_key_values.5.decorder.value:[0,1,32,128] past_key_values.6.decorder.key:[0,1,32,128] past_key_values.6.decorder.value:[0,1,32,128] past_key_values.7.decorder.key:[0,1,32,128] past_key_values.7.decorder.value:[0,1,32,128] past_key_values.8.decorder.key:[0,1,32,128] past_key_values.8.decorder.value:[0,1,32,128] past_key_values.9.decorder.key:[0,1,32,128] past_key_values.9.decorder.value:[0,1,32,128] past_key_values.10.decorder.key:[0,1,32,128] past_key_values.10.decorder.value:[0,1,32,128] past_key_values.11.decorder.key:[0,1,32,128] past_key_values.11.decorder.value:[0,1,32,128] past_key_values.12.decorder.key:[0,1,32,128] past_key_values.12.decorder.value:[0,1,32,128] past_key_values.13.decorder.key:[0,1,32,128] past_key_values.13.decorder.value:[0,1,32,128] past_key_values.14.decorder.key:[0,1,32,128] past_key_values.14.decorder.value:[0,1,32,128] past_key_values.15.decorder.key:[0,1,32,128] past_key_values.15.decorder.value:[0,1,32,128] past_key_values.16.decorder.key:[0,1,32,128] past_key_values.16.decorder.value:[0,1,32,128] past_key_values.17.decorder.key:[0,1,32,128] past_key_values.17.decorder.value:[0,1,32,128] past_key_values.18.decorder.key:[0,1,32,128] past_key_values.18.decorder.value:[0,1,32,128] past_key_values.19.decorder.key:[0,1,32,128] past_key_values.19.decorder.value:[0,1,32,128] past_key_values.20.decorder.key:[0,1,32,128] past_key_values.20.decorder.value:[0,1,32,128] past_key_values.21.decorder.key:[0,1,32,128] past_key_values.21.decorder.value:[0,1,32,128] past_key_values.22.decorder.key:[0,1,32,128] past_key_values.22.decorder.value:[0,1,32,128] past_key_values.23.decorder.key:[0,1,32,128] past_key_values.23.decorder.value:[0,1,32,128] past_key_values.24.decorder.key:[0,1,32,128] past_key_values.24.decorder.value:[0,1,32,128] past_key_values.25.decorder.key:[0,1,32,128] past_key_values.25.decorder.value:[0,1,32,128] past_key_values.26.decorder.key:[0,1,32,128] past_key_values.26.decorder.value:[0,1,32,128] past_key_values.27.decorder.key:[0,1,32,128] past_key_values.27.decorder.value:[0,1,32,128] past_key_values.28.decorder.key:[0,1,32,128] past_key_values.28.decorder.value:[0,1,32,128]' \
#     --trt-max-shapes 'input_ids:[1,1024] position_ids:[1,2,1024] attention_mask:[1,1,1024,1024] past_key_values.1.decorder.key:[0,1,32,128] past_key_values.1.decorder.value:[0,1,32,128] past_key_values.2.decorder.key:[0,1,32,128] past_key_values.2.decorder.value:[0,1,32,128] past_key_values.3.decorder.key:[0,1,32,128] past_key_values.3.decorder.value:[0,1,32,128] past_key_values.4.decorder.key:[0,1,32,128] past_key_values.4.decorder.value:[0,1,32,128] past_key_values.5.decorder.key:[0,1,32,128] past_key_values.5.decorder.value:[0,1,32,128] past_key_values.6.decorder.key:[0,1,32,128] past_key_values.6.decorder.value:[0,1,32,128] past_key_values.7.decorder.key:[0,1,32,128] past_key_values.7.decorder.value:[0,1,32,128] past_key_values.8.decorder.key:[0,1,32,128] past_key_values.8.decorder.value:[0,1,32,128] past_key_values.9.decorder.key:[0,1,32,128] past_key_values.9.decorder.value:[0,1,32,128] past_key_values.10.decorder.key:[0,1,32,128] past_key_values.10.decorder.value:[0,1,32,128] past_key_values.11.decorder.key:[0,1,32,128] past_key_values.11.decorder.value:[0,1,32,128] past_key_values.12.decorder.key:[0,1,32,128] past_key_values.12.decorder.value:[0,1,32,128] past_key_values.13.decorder.key:[0,1,32,128] past_key_values.13.decorder.value:[0,1,32,128] past_key_values.14.decorder.key:[0,1,32,128] past_key_values.14.decorder.value:[0,1,32,128] past_key_values.15.decorder.key:[0,1,32,128] past_key_values.15.decorder.value:[0,1,32,128] past_key_values.16.decorder.key:[0,1,32,128] past_key_values.16.decorder.value:[0,1,32,128] past_key_values.17.decorder.key:[0,1,32,128] past_key_values.17.decorder.value:[0,1,32,128] past_key_values.18.decorder.key:[0,1,32,128] past_key_values.18.decorder.value:[0,1,32,128] past_key_values.19.decorder.key:[0,1,32,128] past_key_values.19.decorder.value:[0,1,32,128] past_key_values.20.decorder.key:[0,1,32,128] past_key_values.20.decorder.value:[0,1,32,128] past_key_values.21.decorder.key:[0,1,32,128] past_key_values.21.decorder.value:[0,1,32,128] past_key_values.22.decorder.key:[0,1,32,128] past_key_values.22.decorder.value:[0,1,32,128] past_key_values.23.decorder.key:[0,1,32,128] past_key_values.23.decorder.value:[0,1,32,128] past_key_values.24.decorder.key:[0,1,32,128] past_key_values.24.decorder.value:[0,1,32,128] past_key_values.25.decorder.key:[0,1,32,128] past_key_values.25.decorder.value:[0,1,32,128] past_key_values.26.decorder.key:[0,1,32,128] past_key_values.26.decorder.value:[0,1,32,128] past_key_values.27.decorder.key:[0,1,32,128] past_key_values.27.decorder.value:[0,1,32,128] past_key_values.28.decorder.key:[0,1,32,128] past_key_values.28.decorder.value:[0,1,32,128]'

export onnx_model_path=../output/onnx_output/chatglm_6b.onnx

# Shape specs, one "name:[shape]" entry per array element. Polygraphy's shape
# options take each tensor's spec as a separate argument; packing several specs
# into one quoted string hands them over as a single malformed entry. Quoting
# each element also stops the shell from globbing the [..] brackets.
trt_opt_shapes=(
    'input_ids:[1,512]'
    'position_ids:[1,2,512]'
    'attention_mask:[1,1,512,512]'
    'past_key_values.0.decorder.key:[0,1,32,128]'
    'past_key_values.0.decorder.value:[0,1,32,128]'
)
# Feed the comparison run the same shapes the TensorRT profile is optimized for.
input_shapes=("${trt_opt_shapes[@]}")

polygraphy run "${onnx_model_path}" \
    --onnxrt \
    --trt \
    --workspace 1000000000 \
    --save-engine=../models/model-FP32-MarkAll.plan \
    --atol 1e-3 --rtol 1e-3 \
    --verbose \
    --onnx-outputs mark all \
    --trt-outputs mark all \
    --trt-opt-shapes "${trt_opt_shapes[@]}" \
    --input-shapes "${input_shapes[@]}" \
    --gen-script compare.py

# With --gen-script, Polygraphy writes compare.py and exits instead of running
# the comparison. To run it directly and keep a log, drop --gen-script and end
# the command above with:
#     > result-run-FP32-MarkAll-profile1.log 2>&1